{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0202a11e",
   "metadata": {},
   "source": [
    "## Install\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "3fdd19f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0mLooking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: pip in /root/test/lib/python3.10/site-packages (24.3.1)\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
      "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n",
      "Writing to /root/.config/pip/pip.conf\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0mLooking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: modelscope==1.16.1 in /root/test/lib/python3.10/site-packages (1.16.1)\n",
      "Requirement already satisfied: requests>=2.25 in /root/test/lib/python3.10/site-packages (from modelscope==1.16.1) (2.32.3)\n",
      "Requirement already satisfied: tqdm>=4.64.0 in /root/test/lib/python3.10/site-packages (from modelscope==1.16.1) (4.67.1)\n",
      "Requirement already satisfied: urllib3>=1.26 in /root/test/lib/python3.10/site-packages (from modelscope==1.16.1) (2.2.3)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /root/test/lib/python3.10/site-packages (from requests>=2.25->modelscope==1.16.1) (3.4.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /root/test/lib/python3.10/site-packages (from requests>=2.25->modelscope==1.16.1) (3.10)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /root/test/lib/python3.10/site-packages (from requests>=2.25->modelscope==1.16.1) (2024.12.14)\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
      "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0mLooking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: transformers in /root/test/lib/python3.10/site-packages (4.47.1)\n",
      "Requirement already satisfied: filelock in /root/test/lib/python3.10/site-packages (from transformers) (3.16.1)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.24.0 in /root/test/lib/python3.10/site-packages (from transformers) (0.27.0)\n",
      "Requirement already satisfied: numpy>=1.17 in /root/test/lib/python3.10/site-packages (from transformers) (1.26.4)\n",
      "Requirement already satisfied: packaging>=20.0 in /root/test/lib/python3.10/site-packages (from transformers) (24.2)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /root/test/lib/python3.10/site-packages (from transformers) (6.0.2)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /root/test/lib/python3.10/site-packages (from transformers) (2024.11.6)\n",
      "Requirement already satisfied: requests in /root/test/lib/python3.10/site-packages (from transformers) (2.32.3)\n",
      "Requirement already satisfied: tokenizers<0.22,>=0.21 in /root/test/lib/python3.10/site-packages (from transformers) (0.21.0)\n",
      "Requirement already satisfied: safetensors>=0.4.1 in /root/test/lib/python3.10/site-packages (from transformers) (0.4.5)\n",
      "Requirement already satisfied: tqdm>=4.27 in /root/test/lib/python3.10/site-packages (from transformers) (4.67.1)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in /root/test/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.24.0->transformers) (2024.5.0)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /root/test/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.24.0->transformers) (4.12.2)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /root/test/lib/python3.10/site-packages (from requests->transformers) (3.4.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /root/test/lib/python3.10/site-packages (from requests->transformers) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /root/test/lib/python3.10/site-packages (from requests->transformers) (2.2.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /root/test/lib/python3.10/site-packages (from requests->transformers) (2024.12.14)\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
      "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0mLooking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: accelerate in /root/test/lib/python3.10/site-packages (0.32.1)\n",
      "Collecting accelerate\n",
      "  Using cached https://pypi.tuna.tsinghua.edu.cn/packages/c2/60/a585c806d6c0ec5f8149d44eb202714792802f484e6e2b1bf96b23bd2b00/accelerate-1.2.1-py3-none-any.whl (336 kB)\n",
      "Requirement already satisfied: numpy<3.0.0,>=1.17 in /root/test/lib/python3.10/site-packages (from accelerate) (1.26.4)\n",
      "Requirement already satisfied: packaging>=20.0 in /root/test/lib/python3.10/site-packages (from accelerate) (24.2)\n",
      "Requirement already satisfied: psutil in /root/test/lib/python3.10/site-packages (from accelerate) (5.9.0)\n",
      "Requirement already satisfied: pyyaml in /root/test/lib/python3.10/site-packages (from accelerate) (6.0.2)\n",
      "Requirement already satisfied: torch>=1.10.0 in /root/test/lib/python3.10/site-packages (from accelerate) (2.5.1+cu118)\n",
      "Requirement already satisfied: huggingface-hub>=0.21.0 in /root/test/lib/python3.10/site-packages (from accelerate) (0.27.0)\n",
      "Requirement already satisfied: safetensors>=0.4.3 in /root/test/lib/python3.10/site-packages (from accelerate) (0.4.5)\n",
      "Requirement already satisfied: filelock in /root/test/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (3.16.1)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in /root/test/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (2024.5.0)\n",
      "Requirement already satisfied: requests in /root/test/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (2.32.3)\n",
      "Requirement already satisfied: tqdm>=4.42.1 in /root/test/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (4.67.1)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /root/test/lib/python3.10/site-packages (from huggingface-hub>=0.21.0->accelerate) (4.12.2)\n",
      "Requirement already satisfied: networkx in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.1.3)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.8.89 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.8.89)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.8.89 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.8.89)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.8.87 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.8.87)\n",
      "Requirement already satisfied: nvidia-cudnn-cu11==9.1.0.70 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (9.1.0.70)\n",
      "Requirement already satisfied: nvidia-cublas-cu11==11.11.3.6 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.11.3.6)\n",
      "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (10.9.0.58)\n",
      "Requirement already satisfied: nvidia-curand-cu11==10.3.0.86 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (10.3.0.86)\n",
      "Requirement already satisfied: nvidia-cusolver-cu11==11.4.1.48 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.4.1.48)\n",
      "Requirement already satisfied: nvidia-cusparse-cu11==11.7.5.86 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.7.5.86)\n",
      "Requirement already satisfied: nvidia-nccl-cu11==2.21.5 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (2.21.5)\n",
      "Requirement already satisfied: nvidia-nvtx-cu11==11.8.86 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (11.8.86)\n",
      "Requirement already satisfied: triton==3.1.0 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (3.1.0)\n",
      "Requirement already satisfied: sympy==1.13.1 in /root/test/lib/python3.10/site-packages (from torch>=1.10.0->accelerate) (1.13.1)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /root/test/lib/python3.10/site-packages (from sympy==1.13.1->torch>=1.10.0->accelerate) (1.3.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /root/test/lib/python3.10/site-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /root/test/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.4.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /root/test/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /root/test/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2.2.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /root/test/lib/python3.10/site-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2024.12.14)\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0mInstalling collected packages: accelerate\n",
      "  Attempting uninstall: accelerate\n",
      "    Found existing installation: accelerate 0.32.1\n",
      "    Uninstalling accelerate-0.32.1:\n",
      "      Successfully uninstalled accelerate-0.32.1\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0mSuccessfully installed accelerate-1.2.1\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
      "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0mLooking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: peft==0.11.1 in /root/test/lib/python3.10/site-packages (0.11.1)\n",
      "Requirement already satisfied: numpy>=1.17 in /root/test/lib/python3.10/site-packages (from peft==0.11.1) (1.26.4)\n",
      "Requirement already satisfied: packaging>=20.0 in /root/test/lib/python3.10/site-packages (from peft==0.11.1) (24.2)\n",
      "Requirement already satisfied: psutil in /root/test/lib/python3.10/site-packages (from peft==0.11.1) (5.9.0)\n",
      "Requirement already satisfied: pyyaml in /root/test/lib/python3.10/site-packages (from peft==0.11.1) (6.0.2)\n",
      "Requirement already satisfied: torch>=1.13.0 in /root/test/lib/python3.10/site-packages (from peft==0.11.1) (2.5.1+cu118)\n",
      "Requirement already satisfied: transformers in /root/test/lib/python3.10/site-packages (from peft==0.11.1) (4.47.1)\n",
      "Requirement already satisfied: tqdm in /root/test/lib/python3.10/site-packages (from peft==0.11.1) (4.67.1)\n",
      "Requirement already satisfied: accelerate>=0.21.0 in /root/test/lib/python3.10/site-packages (from peft==0.11.1) (1.2.1)\n",
      "Requirement already satisfied: safetensors in /root/test/lib/python3.10/site-packages (from peft==0.11.1) (0.4.5)\n",
      "Requirement already satisfied: huggingface-hub>=0.17.0 in /root/test/lib/python3.10/site-packages (from peft==0.11.1) (0.27.0)\n",
      "Requirement already satisfied: filelock in /root/test/lib/python3.10/site-packages (from huggingface-hub>=0.17.0->peft==0.11.1) (3.16.1)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in /root/test/lib/python3.10/site-packages (from huggingface-hub>=0.17.0->peft==0.11.1) (2024.5.0)\n",
      "Requirement already satisfied: requests in /root/test/lib/python3.10/site-packages (from huggingface-hub>=0.17.0->peft==0.11.1) (2.32.3)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /root/test/lib/python3.10/site-packages (from huggingface-hub>=0.17.0->peft==0.11.1) (4.12.2)\n",
      "Requirement already satisfied: networkx in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (3.1.3)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.8.89 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (11.8.89)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.8.89 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (11.8.89)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.8.87 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (11.8.87)\n",
      "Requirement already satisfied: nvidia-cudnn-cu11==9.1.0.70 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (9.1.0.70)\n",
      "Requirement already satisfied: nvidia-cublas-cu11==11.11.3.6 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (11.11.3.6)\n",
      "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (10.9.0.58)\n",
      "Requirement already satisfied: nvidia-curand-cu11==10.3.0.86 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (10.3.0.86)\n",
      "Requirement already satisfied: nvidia-cusolver-cu11==11.4.1.48 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (11.4.1.48)\n",
      "Requirement already satisfied: nvidia-cusparse-cu11==11.7.5.86 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (11.7.5.86)\n",
      "Requirement already satisfied: nvidia-nccl-cu11==2.21.5 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (2.21.5)\n",
      "Requirement already satisfied: nvidia-nvtx-cu11==11.8.86 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (11.8.86)\n",
      "Requirement already satisfied: triton==3.1.0 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (3.1.0)\n",
      "Requirement already satisfied: sympy==1.13.1 in /root/test/lib/python3.10/site-packages (from torch>=1.13.0->peft==0.11.1) (1.13.1)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /root/test/lib/python3.10/site-packages (from sympy==1.13.1->torch>=1.13.0->peft==0.11.1) (1.3.0)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /root/test/lib/python3.10/site-packages (from transformers->peft==0.11.1) (2024.11.6)\n",
      "Requirement already satisfied: tokenizers<0.22,>=0.21 in /root/test/lib/python3.10/site-packages (from transformers->peft==0.11.1) (0.21.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /root/test/lib/python3.10/site-packages (from jinja2->torch>=1.13.0->peft==0.11.1) (2.1.5)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /root/test/lib/python3.10/site-packages (from requests->huggingface-hub>=0.17.0->peft==0.11.1) (3.4.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /root/test/lib/python3.10/site-packages (from requests->huggingface-hub>=0.17.0->peft==0.11.1) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /root/test/lib/python3.10/site-packages (from requests->huggingface-hub>=0.17.0->peft==0.11.1) (2.2.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /root/test/lib/python3.10/site-packages (from requests->huggingface-hub>=0.17.0->peft==0.11.1) (2024.12.14)\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
      "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0mLooking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: datasets==2.20.0 in /root/test/lib/python3.10/site-packages (2.20.0)\n",
      "Requirement already satisfied: filelock in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (3.16.1)\n",
      "Requirement already satisfied: numpy>=1.17 in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (1.26.4)\n",
      "Requirement already satisfied: pyarrow>=15.0.0 in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (18.1.0)\n",
      "Requirement already satisfied: pyarrow-hotfix in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (0.6)\n",
      "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (0.3.8)\n",
      "Requirement already satisfied: pandas in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (2.2.3)\n",
      "Requirement already satisfied: requests>=2.32.2 in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (2.32.3)\n",
      "Requirement already satisfied: tqdm>=4.66.3 in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (4.67.1)\n",
      "Requirement already satisfied: xxhash in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (3.5.0)\n",
      "Requirement already satisfied: multiprocess in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (0.70.16)\n",
      "Requirement already satisfied: fsspec<=2024.5.0,>=2023.1.0 in /root/test/lib/python3.10/site-packages (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets==2.20.0) (2024.5.0)\n",
      "Requirement already satisfied: aiohttp in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (3.11.11)\n",
      "Requirement already satisfied: huggingface-hub>=0.21.2 in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (0.27.0)\n",
      "Requirement already satisfied: packaging in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (24.2)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /root/test/lib/python3.10/site-packages (from datasets==2.20.0) (6.0.2)\n",
      "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /root/test/lib/python3.10/site-packages (from aiohttp->datasets==2.20.0) (2.4.4)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /root/test/lib/python3.10/site-packages (from aiohttp->datasets==2.20.0) (1.3.2)\n",
      "Requirement already satisfied: async-timeout<6.0,>=4.0 in /root/test/lib/python3.10/site-packages (from aiohttp->datasets==2.20.0) (5.0.1)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /root/test/lib/python3.10/site-packages (from aiohttp->datasets==2.20.0) (24.3.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /root/test/lib/python3.10/site-packages (from aiohttp->datasets==2.20.0) (1.5.0)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /root/test/lib/python3.10/site-packages (from aiohttp->datasets==2.20.0) (6.1.0)\n",
      "Requirement already satisfied: propcache>=0.2.0 in /root/test/lib/python3.10/site-packages (from aiohttp->datasets==2.20.0) (0.2.1)\n",
      "Requirement already satisfied: yarl<2.0,>=1.17.0 in /root/test/lib/python3.10/site-packages (from aiohttp->datasets==2.20.0) (1.18.3)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /root/test/lib/python3.10/site-packages (from huggingface-hub>=0.21.2->datasets==2.20.0) (4.12.2)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /root/test/lib/python3.10/site-packages (from requests>=2.32.2->datasets==2.20.0) (3.4.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /root/test/lib/python3.10/site-packages (from requests>=2.32.2->datasets==2.20.0) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /root/test/lib/python3.10/site-packages (from requests>=2.32.2->datasets==2.20.0) (2.2.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /root/test/lib/python3.10/site-packages (from requests>=2.32.2->datasets==2.20.0) (2024.12.14)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /root/test/lib/python3.10/site-packages (from pandas->datasets==2.20.0) (2.9.0.post0)\n",
      "Requirement already satisfied: pytz>=2020.1 in /root/test/lib/python3.10/site-packages (from pandas->datasets==2.20.0) (2024.2)\n",
      "Requirement already satisfied: tzdata>=2022.7 in /root/test/lib/python3.10/site-packages (from pandas->datasets==2.20.0) (2024.2)\n",
      "Requirement already satisfied: six>=1.5 in /root/test/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->datasets==2.20.0) (1.17.0)\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
      "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0mLooking in indexes: https://download.pytorch.org/whl/cu118\n",
      "Requirement already satisfied: torch in /root/test/lib/python3.10/site-packages (2.5.1+cu118)\n",
      "Requirement already satisfied: torchvision in /root/test/lib/python3.10/site-packages (0.20.1+cu118)\n",
      "Requirement already satisfied: torchaudio in /root/test/lib/python3.10/site-packages (2.5.1+cu118)\n",
      "Requirement already satisfied: filelock in /root/test/lib/python3.10/site-packages (from torch) (3.16.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /root/test/lib/python3.10/site-packages (from torch) (4.12.2)\n",
      "Requirement already satisfied: networkx in /root/test/lib/python3.10/site-packages (from torch) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /root/test/lib/python3.10/site-packages (from torch) (3.1.3)\n",
      "Requirement already satisfied: fsspec in /root/test/lib/python3.10/site-packages (from torch) (2024.5.0)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.8.89 in /root/test/lib/python3.10/site-packages (from torch) (11.8.89)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.8.89 in /root/test/lib/python3.10/site-packages (from torch) (11.8.89)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.8.87 in /root/test/lib/python3.10/site-packages (from torch) (11.8.87)\n",
      "Requirement already satisfied: nvidia-cudnn-cu11==9.1.0.70 in /root/test/lib/python3.10/site-packages (from torch) (9.1.0.70)\n",
      "Requirement already satisfied: nvidia-cublas-cu11==11.11.3.6 in /root/test/lib/python3.10/site-packages (from torch) (11.11.3.6)\n",
      "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /root/test/lib/python3.10/site-packages (from torch) (10.9.0.58)\n",
      "Requirement already satisfied: nvidia-curand-cu11==10.3.0.86 in /root/test/lib/python3.10/site-packages (from torch) (10.3.0.86)\n",
      "Requirement already satisfied: nvidia-cusolver-cu11==11.4.1.48 in /root/test/lib/python3.10/site-packages (from torch) (11.4.1.48)\n",
      "Requirement already satisfied: nvidia-cusparse-cu11==11.7.5.86 in /root/test/lib/python3.10/site-packages (from torch) (11.7.5.86)\n",
      "Requirement already satisfied: nvidia-nccl-cu11==2.21.5 in /root/test/lib/python3.10/site-packages (from torch) (2.21.5)\n",
      "Requirement already satisfied: nvidia-nvtx-cu11==11.8.86 in /root/test/lib/python3.10/site-packages (from torch) (11.8.86)\n",
      "Requirement already satisfied: triton==3.1.0 in /root/test/lib/python3.10/site-packages (from torch) (3.1.0)\n",
      "Requirement already satisfied: sympy==1.13.1 in /root/test/lib/python3.10/site-packages (from torch) (1.13.1)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /root/test/lib/python3.10/site-packages (from sympy==1.13.1->torch) (1.3.0)\n",
      "Requirement already satisfied: numpy in /root/test/lib/python3.10/site-packages (from torchvision) (1.26.4)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /root/test/lib/python3.10/site-packages (from torchvision) (10.2.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /root/test/lib/python3.10/site-packages (from jinja2->torch) (2.1.5)\n",
      "\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution -odelscope (/root/test/lib/python3.10/site-packages)\u001b[0m\u001b[33m\n",
      "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
      "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "# Configure the package mirror first so every later install (including the pip upgrade) uses it.\n",
    "%pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple\n",
    "%pip install -q --upgrade pip\n",
    "# Install the CUDA 11.8 PyTorch build BEFORE packages that depend on torch (accelerate, peft).\n",
    "# On a fresh environment those would otherwise pull the default torch wheel from the mirror,\n",
    "# and this command would then be skipped as \"Requirement already satisfied\".\n",
    "%pip install -q torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118\n",
    "# Versions pinned to those from the last successful run so Restart & Run All reproduces the\n",
    "# same environment; installing them in one call lets pip resolve them together.\n",
    "%pip install -q modelscope==1.16.1 transformers==4.47.1 accelerate==1.2.1 peft==0.11.1 datasets==2.20.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "52fac949-4150-4091-b0c3-2968ab5e385c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/test/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# Third-party libraries used for data loading and fine-tuning.\n",
    "import pandas as pd\n",
    "from datasets import Dataset\n",
    "from transformers import (\n",
    "    AutoModelForCausalLM,\n",
    "    AutoTokenizer,\n",
    "    DataCollatorForSeq2Seq,\n",
    "    GenerationConfig,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e098d9eb",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'instruction': ['祝赵老师生日,文艺风格', '祝赵老师春节,文艺风格', '祝赵老师元宵节,文艺风格'],\n",
       " 'input': ['', '', ''],\n",
       " 'output': ['赵老师，岁月在您的身上留下了温柔的痕迹，如同五月的晨光，不愠不火，却总能温暖人心。记得第一次见您时，您正坐在窗边批改作业，阳光透过树叶的缝隙，洒在您的身上，那一刻，我仿佛看到了知识与智慧的光芒。您的每一堂课，都如木槿的清香，让人沉醉。在您的生日这天，希望时间能为您停留，让每一天都充满温馨与美好。愿您依然保持那份纯真的笑容，继续为我们点亮前行的道路。在未来的日子里，无论风雨，我们都将铭记您的教诲，勇敢地追逐梦想。感谢您，不仅教会我们知识，更教会我们如何成为更好的人。',\n",
       "  '赵老师，春节之际，我总想起那些在校园里度过的时光。您曾像一盏明灯，照亮我们前行的道路，用智慧的光芒驱散未知的迷雾。记得那个冬天，您手捧热茶，耐心为我们解答难题，窗外的雪静静飘落，那一刻，时间仿佛凝固，只留下知识的种子在心底悄悄发芽。您的言传身教，如同春风拂过山野，让每一颗心都感受到了温暖与希望。在这辞旧迎新的时刻，祝愿您在新的一年里，依然能够感受到生活的美好，家庭幸福，身体健康，事业顺利。愿岁月温柔以待，愿您心中的那份热爱永远不减，继续在知识的海洋里泛舟，带领更多的人寻找到属于自己的光。',\n",
       "  '在这一年的元宵佳节，我不禁回想起与您相处的点点滴滴。那些关于知识的探讨、关于生活的感悟，都如同一盏盏明灯，照亮了我前行的道路。您的话语总能像春日里的暖阳，驱散我心中的阴霾。感谢您用智慧和耐心，为我的成长播下希望的种子。在这个团圆的时刻，祝您赵老师身体健康，幸福安康，生活如这元宵节的灯笼般，明亮而温暖。']}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the merged instruction data (a JSON file) into a DataFrame and wrap it as a Hugging Face Dataset.\n",
    "DATA_PATH = '/home/merged.json'\n",
    "df = pd.read_json(DATA_PATH)\n",
    "# Fail fast if the file lacks the instruction / input / output fields this notebook expects.\n",
    "assert {'instruction', 'input', 'output'} <= set(df.columns), f'unexpected columns: {list(df.columns)}'\n",
    "ds = Dataset.from_pandas(df)\n",
    "ds[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "45787138",
   "metadata": {},
   "outputs": [],
   "source": [
    "from modelscope import snapshot_download, AutoModel, AutoTokenizer\n",
    "model_dir = snapshot_download('qwen/Qwen2.5-7B-Instruct', cache_dir='/home/temp', revision='master')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ab4fc313",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_PATH = \"/home/temp/qwen/Qwen2___5-7B-Instruct\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51d05e5d-d14e-4f03-92be-9a9677d41918",
   "metadata": {},
   "source": [
    "# 处理数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "74ee5a67-2e55-4974-b90e-cbf492de500a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Qwen2Tokenizer(name_or_path='/home/temp/qwen/Qwen2___5-7B-Instruct', vocab_size=151643, model_max_length=131072, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<|im_end|>', 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|im_start|>', '<|im_end|>', '<|object_ref_start|>', '<|object_ref_end|>', '<|box_start|>', '<|box_end|>', '<|quad_start|>', '<|quad_end|>', '<|vision_start|>', '<|vision_end|>', '<|vision_pad|>', '<|image_pad|>', '<|video_pad|>']}, clean_up_tokenization_spaces=False, added_tokens_decoder={\n",
       "\t151643: AddedToken(\"<|endoftext|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151644: AddedToken(\"<|im_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151645: AddedToken(\"<|im_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151646: AddedToken(\"<|object_ref_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151647: AddedToken(\"<|object_ref_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151648: AddedToken(\"<|box_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151649: AddedToken(\"<|box_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151650: AddedToken(\"<|quad_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151651: AddedToken(\"<|quad_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151652: AddedToken(\"<|vision_start|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151653: AddedToken(\"<|vision_end|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151654: AddedToken(\"<|vision_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151655: AddedToken(\"<|image_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151656: AddedToken(\"<|video_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
       "\t151657: AddedToken(\"<tool_call>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151658: AddedToken(\"</tool_call>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151659: AddedToken(\"<|fim_prefix|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151660: AddedToken(\"<|fim_middle|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151661: AddedToken(\"<|fim_suffix|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151662: AddedToken(\"<|fim_pad|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151663: AddedToken(\"<|repo_name|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "\t151664: AddedToken(\"<|file_sep|>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=False),\n",
       "}\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False, trust_remote_code=True)\n",
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2503a5fa-9621-4495-9035-8e7ef6525691",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def process_func(example):\n",
    "    MAX_LENGTH = 1024    # Llama分词器会将一个中文字切分为多个token，因此需要放开一些最大长度，保证数据的完整性\n",
    "    input_ids, attention_mask, labels = [], [], []\n",
    "    instruction = tokenizer(f\"<|im_start|>system\\n你现在是一个送祝福大师，帮我针对不同人和事情、节日送对应的祝福<|im_end|>\\n<|im_start|>user\\n{example['instruction'] + example['input']}<|im_end|>\\n<|im_start|>assistant\\n\", add_special_tokens=False)  # add_special_tokens 不在开头加 special_tokens\n",
    "    response = tokenizer(f\"{example['output']}\", add_special_tokens=False)\n",
    "    input_ids = instruction[\"input_ids\"] + response[\"input_ids\"] + [tokenizer.pad_token_id]\n",
    "    attention_mask = instruction[\"attention_mask\"] + response[\"attention_mask\"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1\n",
    "    labels = [-100] * len(instruction[\"input_ids\"]) + response[\"input_ids\"] + [tokenizer.pad_token_id]  \n",
    "    if len(input_ids) > MAX_LENGTH:  # 做一个截断\n",
    "        input_ids = input_ids[:MAX_LENGTH]\n",
    "        attention_mask = attention_mask[:MAX_LENGTH]\n",
    "        labels = labels[:MAX_LENGTH]\n",
    "    return {\n",
    "        \"input_ids\": input_ids,\n",
    "        \"attention_mask\": attention_mask,\n",
    "        \"labels\": labels\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "84f870d6-73a9-4b0f-8abf-687b32224ad8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 3276/3276 [00:14<00:00, 220.23 examples/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input_ids', 'attention_mask', 'labels'],\n",
       "    num_rows: 3276\n",
       "})"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_id = ds.map(process_func, remove_columns=ds.column_names)\n",
    "tokenized_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1f7e15a0-4d9a-4935-9861-00cc472654b1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<|im_start|>system\\n你现在是一个送祝福大师，帮我针对不同人和事情、节日送对应的祝福<|im_end|>\\n<|im_start|>user\\n祝赵老师生日,文艺风格<|im_end|>\\n<|im_start|>assistant\\n赵老师，岁月在您的身上留下了温柔的痕迹，如同五月的晨光，不愠不火，却总能温暖人心。记得第一次见您时，您正坐在窗边批改作业，阳光透过树叶的缝隙，洒在您的身上，那一刻，我仿佛看到了知识与智慧的光芒。您的每一堂课，都如木槿的清香，让人沉醉。在您的生日这天，希望时间能为您停留，让每一天都充满温馨与美好。愿您依然保持那份纯真的笑容，继续为我们点亮前行的道路。在未来的日子里，无论风雨，我们都将铭记您的教诲，勇敢地追逐梦想。感谢您，不仅教会我们知识，更教会我们如何成为更好的人。<|endoftext|>'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenized_id[0]['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "97f16f66-324a-454f-8cc3-ef23b100ecff",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'赵老师，春节之际，我总想起那些在校园里度过的时光。您曾像一盏明灯，照亮我们前行的道路，用智慧的光芒驱散未知的迷雾。记得那个冬天，您手捧热茶，耐心为我们解答难题，窗外的雪静静飘落，那一刻，时间仿佛凝固，只留下知识的种子在心底悄悄发芽。您的言传身教，如同春风拂过山野，让每一颗心都感受到了温暖与希望。在这辞旧迎新的时刻，祝愿您在新的一年里，依然能够感受到生活的美好，家庭幸福，身体健康，事业顺利。愿岁月温柔以待，愿您心中的那份热爱永远不减，继续在知识的海洋里泛舟，带领更多的人寻找到属于自己的光。<|endoftext|>'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(list(filter(lambda x: x != -100, tokenized_id[1][\"labels\"])))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "424823a8-ed0d-4309-83c8-3f6b1cdf274c",
   "metadata": {},
   "source": [
    "# 创建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "170764e5-d899-4ef4-8c53-36f6dec0d198",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.11s/it]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Qwen2ForCausalLM(\n",
       "  (model): Qwen2Model(\n",
       "    (embed_tokens): Embedding(152064, 3584)\n",
       "    (layers): ModuleList(\n",
       "      (0-27): 28 x Qwen2DecoderLayer(\n",
       "        (self_attn): Qwen2SdpaAttention(\n",
       "          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)\n",
       "          (k_proj): Linear(in_features=3584, out_features=512, bias=True)\n",
       "          (v_proj): Linear(in_features=3584, out_features=512, bias=True)\n",
       "          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)\n",
       "          (rotary_emb): Qwen2RotaryEmbedding()\n",
       "        )\n",
       "        (mlp): Qwen2MLP(\n",
       "          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)\n",
       "          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)\n",
       "          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)\n",
       "        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)\n",
       "      )\n",
       "    )\n",
       "    (norm): Qwen2RMSNorm((3584,), eps=1e-06)\n",
       "    (rotary_emb): Qwen2RotaryEmbedding()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=3584, out_features=152064, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map=\"auto\",torch_dtype=torch.bfloat16)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2323eac7-37d5-4288-8bc5-79fac7113402",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model.enable_input_require_grads() # 开启梯度检查点时，要执行该方法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f808b05c-f2cb-48cf-a80d-0c42be6051c7",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.bfloat16"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.dtype"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13d71257-3c1c-4303-8ff8-af161ebc2cf1",
   "metadata": {},
   "source": [
    "# lora "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2d304ae2-ab60-4080-a80d-19cac2e3ade3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=32, target_modules={'up_proj', 'gate_proj', 'q_proj', 'v_proj', 'o_proj', 'k_proj', 'down_proj'}, lora_alpha=16, lora_dropout=0.1, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from peft import LoraConfig, TaskType, get_peft_model\n",
    "\n",
    "config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM, \n",
    "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
    "    inference_mode=False, # 训练模式\n",
    "    r=32, # Lora 秩\n",
    "    lora_alpha=16, # Lora alaph，具体作用参见 Lora 原理\n",
    "    lora_dropout=0.1# Dropout 比例\n",
    ")\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2c2489c5-eaab-4e1f-b06a-c3f914b4bf8e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path='/home/temp/qwen/Qwen2___5-7B-Instruct', revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=32, target_modules={'up_proj', 'gate_proj', 'q_proj', 'v_proj', 'o_proj', 'k_proj', 'down_proj'}, lora_alpha=16, lora_dropout=0.1, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = get_peft_model(model, config)\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ebf5482b-fab9-4eb3-ad88-c116def4be12",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 80,740,352 || all params: 7,696,356,864 || trainable%: 1.0491\n"
     ]
    }
   ],
   "source": [
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca055683-837f-4865-9c57-9164ba60c00f",
   "metadata": {},
   "source": [
    "# 配置训练参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7e76bbff-15fd-4995-a61d-8364dc5e9ea0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    output_dir=\"/home/output/Qwen2.5_instruct_lora2\",\n",
    "    per_device_train_batch_size=4,\n",
    "    gradient_accumulation_steps=4,\n",
    "    logging_steps=10,\n",
    "    num_train_epochs=3,\n",
    "    save_steps=100, \n",
    "    learning_rate=1e-4,\n",
    "    save_on_each_node=True,\n",
    "    gradient_checkpointing=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f142cb9c-ad99-48e6-ba86-6df198f9ed96",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=tokenized_id,\n",
    "    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "aec9bc36-b297-45af-99e1-d4c4d82be081",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='612' max='612' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [612/612 22:36, Epoch 2/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>2.207000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>1.915000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>1.943400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>1.814300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>1.701800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>1.705800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>1.588000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>1.519700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>1.420500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>1.461100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>110</td>\n",
       "      <td>1.435700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>1.353400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>130</td>\n",
       "      <td>1.440900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>140</td>\n",
       "      <td>1.370100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>150</td>\n",
       "      <td>1.275300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>160</td>\n",
       "      <td>1.183700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>170</td>\n",
       "      <td>1.322600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>180</td>\n",
       "      <td>1.349900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>190</td>\n",
       "      <td>1.354600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>1.401500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>210</td>\n",
       "      <td>1.231300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>220</td>\n",
       "      <td>1.110300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>230</td>\n",
       "      <td>1.206900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>240</td>\n",
       "      <td>1.081400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>250</td>\n",
       "      <td>1.093600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>260</td>\n",
       "      <td>1.115800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>270</td>\n",
       "      <td>1.241300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>280</td>\n",
       "      <td>1.093800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>290</td>\n",
       "      <td>1.217400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>1.229100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>310</td>\n",
       "      <td>1.151100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>320</td>\n",
       "      <td>1.051800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>330</td>\n",
       "      <td>1.187400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>340</td>\n",
       "      <td>1.090600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>350</td>\n",
       "      <td>1.229000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>360</td>\n",
       "      <td>1.282300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>370</td>\n",
       "      <td>1.125300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>380</td>\n",
       "      <td>1.175100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>390</td>\n",
       "      <td>1.154200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>1.204000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>410</td>\n",
       "      <td>1.112600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>420</td>\n",
       "      <td>1.082500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>430</td>\n",
       "      <td>1.134900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>440</td>\n",
       "      <td>1.063300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>450</td>\n",
       "      <td>1.106900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>460</td>\n",
       "      <td>1.014200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>470</td>\n",
       "      <td>0.970300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>480</td>\n",
       "      <td>1.149000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>490</td>\n",
       "      <td>1.106800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>1.087100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>510</td>\n",
       "      <td>0.981600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>520</td>\n",
       "      <td>1.045300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>530</td>\n",
       "      <td>1.120000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>540</td>\n",
       "      <td>1.093100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>550</td>\n",
       "      <td>1.093200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>560</td>\n",
       "      <td>1.018100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>570</td>\n",
       "      <td>1.084700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>580</td>\n",
       "      <td>1.018200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>590</td>\n",
       "      <td>1.145000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.953200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>610</td>\n",
       "      <td>1.039200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/test/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/temp/qwen/Qwen2___5-7B-Instruct - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/test/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/temp/qwen/Qwen2___5-7B-Instruct - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/test/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/temp/qwen/Qwen2___5-7B-Instruct - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/test/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/temp/qwen/Qwen2___5-7B-Instruct - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/test/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/temp/qwen/Qwen2___5-7B-Instruct - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/test/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/temp/qwen/Qwen2___5-7B-Instruct - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "/root/test/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/temp/qwen/Qwen2___5-7B-Instruct - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=612, training_loss=1.2518605264573315, metrics={'train_runtime': 1359.1261, 'train_samples_per_second': 7.231, 'train_steps_per_second': 0.45, 'total_flos': 1.8222733818803405e+17, 'train_loss': 1.2518605264573315, 'epoch': 2.9865689865689866})"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8abb2327-458e-4e96-ac98-2141b5b97c8e",
   "metadata": {},
   "source": [
    "# 合并加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "bd2a415a-a9ad-49ea-877f-243558a83bfc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.42s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "我是你的智能助手，随时准备为你送上最温馨的祝福。无论是生日、节日还是其他特殊时刻，我都能为你定制专属的祝福语。\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import torch\n",
    "from peft import PeftModel\n",
    "\n",
    "mode_path = MODEL_PATH\n",
    "lora_path = '/home/output/Qwen2.5_instruct_lora2/checkpoint-612' # 这里改称你的 lora 输出对应 checkpoint 地址\n",
    "\n",
    "# 加载tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)\n",
    "\n",
    "# 加载模型\n",
    "model = AutoModelForCausalLM.from_pretrained(mode_path, device_map=\"auto\",torch_dtype=torch.bfloat16, trust_remote_code=True).eval()\n",
    "\n",
    "# 加载lora权重\n",
    "model = PeftModel.from_pretrained(model, model_id=lora_path)\n",
    "\n",
    "prompt = \"你是谁？\"\n",
    "inputs = tokenizer.apply_chat_template([{\"role\": \"system\", \"content\": \"你现在是一个送祝福大师，帮我针对不同人和事情、节日送对应的祝福\"},{\"role\": \"user\", \"content\": prompt}],\n",
    "                                       add_generation_prompt=True,\n",
    "                                       tokenize=True,\n",
    "                                       return_tensors=\"pt\",\n",
    "                                       return_dict=True\n",
    "                                       ).to('cuda')\n",
    "\n",
    "\n",
    "gen_kwargs = {\"max_length\": 2500, \"do_sample\": True, \"top_k\": 1, \"temperature\":0.8}\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(**inputs, **gen_kwargs)\n",
    "    outputs = outputs[:, inputs['input_ids'].shape[1]:]\n",
    "    print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "325a5512",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "同学！腊八节快乐呀！🎉 你家的腊八粥今年又熬得怎么样了？上次你发的照片，那香味儿我都快流口水了！不过你能不能别再偷偷给我寄腊八蒜了？我同事都以为我有亲戚在你家呢！😂 还有啊，你那腊八蒜泡得那么久，是不是想让我也变成个老腊八蒜？不过话说回来，你这手艺真是一绝，下次可得给我多带点！对了，你不是说要研究腊八粥的新配方吗？别忘了邀请我一起尝尝哦！最后，记得腊八节那天一定要多吃点，别光顾着工作，身体才是革命的本钱嘛！😊\n"
     ]
    }
   ],
   "source": [
    "prompt = \"祝同学腊八节,放飞自我风格\"\n",
    "inputs = tokenizer.apply_chat_template([{\"role\": \"system\", \"content\": \"你现在是一个送祝福大师，帮我针对不同人和事情、节日送对应的祝福\"},{\"role\": \"user\", \"content\": prompt}],\n",
    "                                       add_generation_prompt=True,\n",
    "                                       tokenize=True,\n",
    "                                       return_tensors=\"pt\",\n",
    "                                       return_dict=True\n",
    "                                       ).to('cuda')\n",
    "\n",
    "\n",
    "gen_kwargs = {\"max_length\": 4096, \"do_sample\": True, \"temperature\":0.6}\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(**inputs, **gen_kwargs)\n",
    "    outputs = outputs[:, inputs['input_ids'].shape[1]:]\n",
    "    print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "d7a1184c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/test/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/temp/qwen/Qwen2___5-7B-Instruct - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('/home/tianji-wishes-v2.0/tokenizer_config.json',\n",
       " '/home/tianji-wishes-v2.0/special_tokens_map.json',\n",
       " '/home/tianji-wishes-v2.0/vocab.json',\n",
       " '/home/tianji-wishes-v2.0/merges.txt',\n",
       " '/home/tianji-wishes-v2.0/added_tokens.json',\n",
       " '/home/tianji-wishes-v2.0/tokenizer.json')"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 保存模型和分词器\n",
    "save_dir = \"/home/tianji-wishes-v2.0\"\n",
    "model.save_pretrained(save_dir)\n",
    "tokenizer.save_pretrained(save_dir)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "3d2c6e0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.merge_and_unload()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "3614f11c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_pretrained(save_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "db243e6c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('/home/tianji-wishes-v2.0/tokenizer_config.json',\n",
       " '/home/tianji-wishes-v2.0/special_tokens_map.json',\n",
       " '/home/tianji-wishes-v2.0/vocab.json',\n",
       " '/home/tianji-wishes-v2.0/merges.txt',\n",
       " '/home/tianji-wishes-v2.0/added_tokens.json',\n",
       " '/home/tianji-wishes-v2.0/tokenizer.json')"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.save_pretrained(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c4f2d1b",
   "metadata": {},
   "source": [
    "## Inference: fine-tuned vs. base model\n",
    "\n",
    "`generate_wish` loads a checkpoint, generates one blessing for a prompt and then frees the model, so both checkpoints can be compared in the same kernel without holding two 7B models in GPU memory at once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b1e7c2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "\n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "SYSTEM_PROMPT = \"你现在是一个送祝福大师，帮我针对不同人和事情、节日送对应的祝福\"\n",
    "\n",
    "\n",
    "def generate_wish(model_path, prompt, temperature=0.7, max_new_tokens=1024, seed=42):\n",
    "    \"\"\"Load the causal LM at `model_path` and return one sampled reply to `prompt`.\n",
    "\n",
    "    Args:\n",
    "        model_path: local model checkpoint directory (base model or fine-tuned output).\n",
    "        prompt: the user request, e.g. recipient + occasion + style.\n",
    "        temperature: sampling temperature passed to `generate`.\n",
    "        max_new_tokens: cap on generated tokens; unlike `max_length` it excludes the prompt.\n",
    "        seed: torch RNG seed so the sampled reply is reproducible.\n",
    "\n",
    "    Returns:\n",
    "        The decoded reply, with prompt tokens and special tokens removed.\n",
    "    \"\"\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
    "        model_path, device_map=\"auto\", torch_dtype=torch.bfloat16, trust_remote_code=True\n",
    "    ).eval()\n",
    "    try:\n",
    "        messages = [\n",
    "            {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
    "            {\"role\": \"user\", \"content\": prompt},\n",
    "        ]\n",
    "        # Move inputs to wherever the model was placed instead of hard-coding 'cuda'.\n",
    "        inputs = tokenizer.apply_chat_template(\n",
    "            messages,\n",
    "            add_generation_prompt=True,\n",
    "            tokenize=True,\n",
    "            return_tensors=\"pt\",\n",
    "            return_dict=True,\n",
    "        ).to(model.device)\n",
    "        torch.manual_seed(seed)\n",
    "        with torch.no_grad():\n",
    "            output_ids = model.generate(\n",
    "                **inputs, do_sample=True, temperature=temperature, max_new_tokens=max_new_tokens\n",
    "            )\n",
    "        # Keep only the newly generated tokens, dropping the echoed prompt.\n",
    "        reply_ids = output_ids[0, inputs[\"input_ids\"].shape[1]:]\n",
    "        return tokenizer.decode(reply_ids, skip_special_tokens=True)\n",
    "    finally:\n",
    "        # Release the weights so the next checkpoint fits in GPU memory.\n",
    "        del model\n",
    "        gc.collect()\n",
    "        if torch.cuda.is_available():\n",
    "            torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ae458c71",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.19s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "同学！腊八节快乐呀！🎉 你家的腊八粥是不是又香又甜？上次你给我带的那碗，我到现在都念念不忘呢！你那手艺简直了，我都快爱上腊八粥了！不过，你能不能别再偷偷往里面加辣椒了？上次我喝了一大口，辣得我眼泪都出来了，结果还得你帮我解辣...😂\n",
      "\n",
      "对了，听说你最近又在研究什么新菜谱？别忘了给我留一份哦！你那手艺，我可是期待已久啦！还有啊，你能不能别再天天发朋友圈晒美食了？我都快被你的厨艺迷住了！不过，你那个“腊八蒜”能不能也别放那么多醋？上次我吃了一个，酸得我直皱眉...😋\n",
      "\n",
      "最后，记得腊八节那天多煮点粥，别忘了叫我一声哦！我可是你的忠实粉丝，你做的每一道菜我都在心里记着呢！哈哈，祝你腊八节快乐，身体健康，厨艺更上一层楼！🧧\n"
     ]
    }
   ],
   "source": [
    "# Fine-tuned checkpoint to evaluate (change to your own LoRA / merged output directory).\n",
    "FINETUNED_MODEL_PATH = \"/home/tianji-wish2-7b\"\n",
    "WISH_PROMPT = \"祝同学腊八节,放飞自我风格\"\n",
    "\n",
    "print(generate_wish(FINETUNED_MODEL_PATH, WISH_PROMPT, temperature=0.7))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d419346e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.39s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "腊八节是中国的传统节日，这一天人们通常会喝腊八粥，庆祝丰收的同时也祈求来年的好运。如果你希望给你的同学送一个既契合腊八节的氛围又能鼓励他/她“放飞自我”的祝福，可以这样写：\n",
      "\n",
      "【腊八节特别祝福】\n",
      "\n",
      "亲爱的[同学名字]，\n",
      "\n",
      "在这个充满喜悦与期待的腊八佳节里，愿你像这碗香甜的腊八粥一样，既有家的味道，又有生活的甜蜜。新的一年，让我们一起勇敢地拥抱变化，大胆地追求梦想，放飞自我，做最真实的自己。愿你每天都有好心情，每一步都走得精彩纷呈！\n",
      "\n",
      "愿你在这个特别的日子里，不仅品尝到腊八粥的美味，还能感受到生活的无限可能。\n",
      "\n",
      "祝福满满，腊八快乐！\n",
      "\n",
      "---\n",
      "\n",
      "这样的祝福既体现了腊八节的文化特色，又包含了鼓励和祝福的话语，非常适合送给想要鼓励的同学。\n"
     ]
    }
   ],
   "source": [
    "# Baseline for comparison: the original Qwen2.5-7B-Instruct weights, without fine-tuning.\n",
    "# NOTE: sampled at temperature 0.8 vs. 0.7 above; use the same value for a like-for-like comparison.\n",
    "BASE_MODEL_PATH = \"/home/temp/qwen/Qwen2___5-7B-Instruct\"\n",
    "\n",
    "print(generate_wish(BASE_MODEL_PATH, WISH_PROMPT, temperature=0.8))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
